"""
 @Author: zzgsg
 @time: 2024/9/28 20:32
"""

# Figure 4. Comparison between the synthetic and the retrieved aerosol extinction profiles
import numpy as np
import subprocess
import shutil
import os
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import matplotlib
matplotlib.rcParams['mathtext.default'] = 'regular'

def e_profile(c, H, z):
    """Return an exponentially decaying profile c/H * exp(-z/H).

    Integrating over z from 0 to infinity gives c.

    Args:
        c: total (column-integrated) amount.
        H: scale height, in the same units as ``z``.
        z: altitude value(s).
    """
    surface_value = c / H
    return surface_value * np.exp(-z / H)

def box_profile(c, H, z):
    """Return a box profile: constant c/H for z <= H and 0 above H.

    For z starting at 0, integrating over z gives c.

    Args:
        c: total (column-integrated) amount.
        H: top of the box, in the same units as ``z``.
        z: numpy array of altitudes (boolean-mask assignment is used).
    """
    values = np.ones_like(z) * c / H
    above_top = z > H
    values[above_top] = 0
    return values

def layered_profile(c, H1, H2, z):
    """Return an elevated layer: constant c/(H2-H1) for H1 < z <= H2, else 0.

    Integrating over z gives c.

    Args:
        c: total (column-integrated) amount.
        H1: layer bottom (exclusive), in the same units as ``z``.
        H2: layer top (inclusive), in the same units as ``z``.
        z: numpy array of altitudes (boolean-mask assignment is used).
    """
    thickness = H2 - H1
    values = np.ones_like(z) * c / thickness
    outside_layer = (z > H2) | (z <= H1)
    values[outside_layer] = 0
    return values

def gauss_profile(c, x0, sigma, z):
    """Return a Gaussian layer centred at x0 with standard deviation sigma.

    The amplitude c / (sigma * sqrt(2*pi)) normalises the curve so that
    integrating over all z gives c.

    Args:
        c: total (integrated) amount.
        x0: centre altitude of the layer.
        sigma: standard deviation (layer half-width parameter).
        z: altitude value(s).
    """
    # Peak amplitude of the normalised Gaussian.
    amplitude = c / (sigma * np.sqrt(2 * np.pi))
    exponent = -((z - x0) ** 2) / (2 * sigma ** 2)
    return amplitude * np.exp(exponent)

def Generate_Profile(z1, z2, z3):
    """Build the synthetic aerosol and gas test profiles.

    Args:
        z1: altitude grid on which the aerosol cases 'aero0'..'aero9' are evaluated.
        z2: altitude grid on which the gas cases 'gas0'..'gas5' are evaluated.
        z3: accepted but not used by this function.

    Returns:
        (aero_dict, gas_dict): two dicts mapping case names to profile arrays.
    """
    aerosols = {
        # Near-zero background case.
        'aero0': 1e-10 * np.ones_like(z1),
        'aero1': e_profile(0.25, 1, z1),
        'aero2': e_profile(1, 1, z1),
        'aero3': e_profile(0.25, 0.25, z1),
        'aero4': box_profile(0.1, 0.2, z1),
        'aero5': box_profile(0.5, 1, z1),
        'aero6': gauss_profile(0.25, 1, 0.3, z1),
        'aero7': box_profile(2, 0.2, z1),
        'aero8': layered_profile(5, 1, 1.5, z1),
        'aero9': layered_profile(5, 5, 5.5, z1),
    }

    gases = {
        'gas0': e_profile(1, 1, z2),
        'gas1': e_profile(1, 0.25, z2),
        'gas2': box_profile(1, 0.3, z2),
        'gas3': box_profile(1, 1, z2),
        'gas4': gauss_profile(1, 1.5, 0.3, z2),
        'gas5': layered_profile(1, 1, 2, z2),
    }

    return aerosols, gases
# Altitude grid for the synthetic ("REAL") profiles: 0.1 steps from 0 to 6,
# then 7, 8, 9, then 10..40 in steps of 6.
z1 = np.r_[np.arange(0., 6.01, 0.1), np.arange(7, 10, 1), np.arange(10, 41, 6)]
# Midpoints of consecutive z_grid_middle levels.
# NOTE(review): z2 is not used in the visible code (z1 is passed for the gas
# grid below) — confirm whether that is intended.
z_grid_middle = np.r_[np.arange(0.01, 4, 0.01), np.arange(4, 10, 1), np.arange(10, 41, 5)]
z2 = (z_grid_middle[:-1] + z_grid_middle[1:]) / 2
# One altitude value per line. Filtering on line.strip() (instead of comparing
# with '\n') also skips whitespace-only and CRLF blank lines, which float()
# would otherwise reject; float() itself ignores the trailing newline.
with open('RTM_grid_bak.dat', 'r') as f:
    z_grid = [float(line) for line in f if line.strip()]
z3 = np.array(z_grid)
aero_dict, gas_dict = Generate_Profile(z1, z1, z3)

fig3, axs = plt.subplots(2, 5, figsize=(25, 10))
plt.subplots_adjust(top=0.95, bottom=0.1, right=0.90, left=0.05, hspace=0.12, wspace=0.)
plt.rcParams['font.size'] = 20
# Upper extinction-axis limit for each AEROi panel (list index = aerosol case).
set_max = [0.008, 0.4, 1.2, 1.2, 0.8, 0.6, 0.4, 24, 12, 6]
# Number of retrieved profiles (plotted as GEOM0..GEOM8) per aerosol case.
n_geom = 9
# Read the retrieval output once (previously it was re-read twice per panel).
# After transposing, row 0 holds the altitudes and rows 1+9*i .. 9+9*i hold the
# retrieved profiles for case AEROi.
retrieved = pd.read_csv('synthetic_retrieved_aerosol.dat', sep='\t').values.T
z_grid = retrieved[0]
# Altitudes of the synthetic profile as (bottom, top) pairs of each z1 interval;
# identical for every panel, so it is built once.
b = np.c_[z1[:-1], z1[1:]].flatten()
for i, ax3 in enumerate(axs.flatten()):
    ax3.set_xlim([0, set_max[i]])
    ax3.set_xticks([0, set_max[i] / 4, set_max[i] / 2, set_max[i] * 0.75])
    # Staircase of the synthetic ("REAL") profile: every value is repeated so it
    # spans the bottom and top of its z1 interval. The last pair is dropped
    # below because no interval lies above the last z1 level.
    a = np.c_[aero_dict['aero{}'.format(i)], aero_dict['aero{}'.format(i)]].flatten()

    ax3.annotate('AERO{}'.format(i), xy=(0.5, 0.95), xycoords='axes fraction', fontsize=20,
                 horizontalalignment='center', verticalalignment='top')

    prof = retrieved[1 + n_geom * i:1 + n_geom * (i + 1)]
    for j in range(n_geom):
        ax3.plot(prof[j], z_grid, label='GEOM{}'.format(j), linewidth=3)
    ax3.plot(a[:-2], b, c='k', label='REAL', linewidth=3)

    ax3.set_yticks([0, 1, 2, 3, 4])
    ax3.set_ylim([0, 4])
    # Panels 0 and 5 are the left-most of each row of the 2x5 grid; only they
    # keep their altitude tick labels.
    if i not in [0, 5]:
        plt.setp(ax3.get_yticklabels(), visible=False)
    ax3.grid(linestyle='-.', linewidth=2)
    plt.setp(ax3.get_xticklabels(), fontsize=20)
    plt.setp(ax3.get_yticklabels(), fontsize=20)
    for side in ('bottom', 'left', 'right', 'top'):
        ax3.spines[side].set_linewidth(3)

# One legend for the whole figure, built from the last panel's lines and
# anchored (in figure coordinates) to the figure's right edge.
ax3.legend(loc='center right', bbox_to_anchor=(1, 0.5), bbox_transform=plt.gcf().transFigure)

# Invisible axes covering the whole figure, used only to place the shared
# x- and y-axis labels.
ax3 = plt.axes([0, 0, 1, 1])
ax3.axis('off')
ax3.annotate('Extinction [$km^{-1}$]', xy=(0.475, 0.05), xycoords='axes fraction',
             horizontalalignment='center', verticalalignment='center', fontsize=25)
ax3.annotate('Altitude [km]', xy=(0.02, 0.5), xycoords='axes fraction', rotation=90,
             horizontalalignment='center', verticalalignment='center', fontsize=25)
# savefig does not create missing directories; make sure the output folder exists.
os.makedirs('temp', exist_ok=True)
plt.savefig('temp/Figure4.eps')
plt.savefig('temp/Figure4.png')
plt.close()